package com.didiglobal.common.db.dbproxy;

import com.didiglobal.common.db.ForwardingConnection;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.regex.Pattern;

/**
 * {@link Connection} decorator that tags DML statements with a DbProxy hint comment.
 *
 * <p>Hints (router, shadow-table mode, sharding value, transaction flag) are set through the fluent
 * methods of this class. {@link #buildStatementComment()} serializes them to JSON. The resulting
 * comment is prepended to every SQL string passed to {@code prepareStatement} whose first keyword is
 * {@code SELECT}, {@code UPDATE}, {@code INSERT} or {@code DELETE}. Plain {@link Statement}s are
 * wrapped in {@code DbProxyStatement}, which is given this connection.
 *
 * <p>The hint setters only update the pending properties. The comment actually applied is
 * recomputed only when {@link #buildStatementComment()} is called. {@link #wrap(Connection)} calls
 * it once; the public constructor does not. A directly constructed instance therefore adds no
 * comment until {@code buildStatementComment()} is invoked.
 *
 * <p>The mutable state of this class is not synchronized.
 *
 * @author liyanling
 * @date 2021/11/23 4:11 PM
 */
public class DbProxyConnection extends ForwardingConnection {

    private static final Logger LOGGER = LoggerFactory.getLogger(DbProxyConnection.class);

    private static final String SHADOW_TABLE_HINT = "shadow";

    private static final String USE_MASTER_HINT = "m";

    /** Template that turns the encoded JSON hint into a leading SQL block comment. */
    private static final String STATEMENT_COMMENT_TMPL = "/* %s */ ";

    /** Sequence that would terminate the SQL block comment if it appeared inside the hint. */
    private static final String COMMENT_TERMINATOR = "*/";

    /** JSON-equivalent form of the terminator: backslash-slash is a valid JSON escape for '/'. */
    private static final String ESCAPED_COMMENT_TERMINATOR = "*\\/";

    /**
     * Recognizes statements that should carry the hint. These are statements whose first keyword,
     * after optional leading whitespace, is SELECT, UPDATE, INSERT or DELETE (case-insensitive).
     * The pattern is used with {@link java.util.regex.Matcher#lookingAt()}, so only the prefix is
     * inspected and SQL spanning several lines is matched too.
     */
    private static final Pattern DML_PATTERN =
            Pattern.compile("^\\s*(select|update|insert|delete)", Pattern.CASE_INSENSITIVE);

    /** Pending hint values; only applied once {@link #buildStatementComment()} is called. */
    private final ConnectionProperties connectionProps = new ConnectionProperties();

    /** Comment currently prepended to DML; {@code null} means SQL is passed through unchanged. */
    private String statementComment = null;

    /**
     * Wraps {@code connection} so that DML statements carry the DbProxy hint comment.
     *
     * <p>A connection that already is a {@code DbProxyConnection} is returned as-is. Otherwise a new
     * wrapper is created and its comment is built right away from the default properties. The
     * defaults are: router {@code "m"}, no mode, no sharding value, {@code is_tx=false}.
     *
     * @param connection the connection to decorate; must not be {@code null}
     * @return the decorated connection
     * @throws IllegalArgumentException if {@code connection} is {@code null}
     */
    public static DbProxyConnection wrap(Connection connection) {
        if (null == connection) {
            throw new IllegalArgumentException("Connection cannot be null");
        }
        if (connection instanceof DbProxyConnection) {
            return (DbProxyConnection) connection;
        }
        return new DbProxyConnection(connection).buildStatementComment();
    }

    /**
     * Creates a wrapper around {@code delegate}.
     *
     * <p>No comment is built here. Call {@link #buildStatementComment()}, or use
     * {@link #wrap(Connection)}, before preparing statements.
     *
     * @param delegate the underlying connection
     */
    public DbProxyConnection(Connection delegate) {
        super(delegate);
    }

    /**
     * Sets or clears the {@code router} hint. {@code true} sets it to {@code "m"}; {@code false}
     * omits it. Takes effect on the next {@link #buildStatementComment()}.
     *
     * @param useMaster whether to emit the router hint
     * @return this connection, for chaining
     */
    public DbProxyConnection useMaster(boolean useMaster) {
        connectionProps.router = (useMaster ? USE_MASTER_HINT : null);
        return this;
    }

    /**
     * Sets or clears the {@code mode} hint. {@code true} sets it to {@code "shadow"}; {@code false}
     * omits it. Takes effect on the next {@link #buildStatementComment()}.
     *
     * @param useShadowTable whether to emit the shadow-table mode hint
     * @return this connection, for chaining
     */
    public DbProxyConnection useShadowTable(boolean useShadowTable) {
        connectionProps.mode = (useShadowTable ? SHADOW_TABLE_HINT : null);
        return this;
    }

    /**
     * Sets the {@code pid} (sharding value) hint to the trimmed {@code shardingValue}. A
     * {@code null} or blank value clears it. Takes effect on the next
     * {@link #buildStatementComment()}.
     *
     * @param shardingValue the sharding value, possibly {@code null}
     * @return this connection, for chaining
     */
    public DbProxyConnection autoSharding(String shardingValue) {
        if (null == shardingValue || "".equals(shardingValue.trim())) {
            connectionProps.pid = null;
        } else {
            connectionProps.pid = shardingValue.trim();
        }
        return this;
    }

    /**
     * Sets the {@code is_tx} flag written into the hint. Takes effect on the next
     * {@link #buildStatementComment()}.
     *
     * <p>Note: this flag alone does not cause a comment to be emitted. A comment is produced only
     * when at least one of router, mode or sharding value is set.
     *
     * @param isTx the transaction flag
     * @return this connection, for chaining
     */
    public DbProxyConnection isTx(boolean isTx) {
        connectionProps.isTx = isTx;
        return this;
    }

    /**
     * Recomputes the comment prepended to DML from the current hint properties.
     *
     * @return this connection, for chaining
     * @throws IllegalStateException if the properties cannot be serialized to JSON
     */
    public DbProxyConnection buildStatementComment() {
        String encodedStatementComment = connectionProps.encodeToStatementComment();
        if (null == encodedStatementComment) {
            statementComment = null;
        } else {
            statementComment = String.format(STATEMENT_COMMENT_TMPL, encodedStatementComment);
        }
        return this;
    }

    /**
     * Prepends the current hint comment to {@code sql} when a comment has been built and
     * {@code sql} starts with a DML keyword. Otherwise {@code sql} is returned unchanged.
     *
     * @param sql the SQL text, possibly {@code null} (returned as-is)
     * @return the SQL to send to the driver
     */
    public String wrapWithStatementComment(String sql) {
        String sqlWithProxyComment = sql;
        // A null sql is passed through so the underlying driver reports it, rather than an NPE here.
        if (null != statementComment && null != sql && DML_PATTERN.matcher(sql).lookingAt()) {
            sqlWithProxyComment = statementComment + sql;
        }
        LOGGER.debug("sqlWithProxyComment={}", sqlWithProxyComment);
        return sqlWithProxyComment;
    }

    /** Prepares {@code sql} with the hint comment applied and wraps the result. */
    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        String sqlWithProxyComment = wrapWithStatementComment(sql);
        PreparedStatement preparedStatement = super.prepareStatement(sqlWithProxyComment);
        return new DbProxyPreparedStatement(preparedStatement, this);
    }

    /** Prepares {@code sql} with the hint comment applied and wraps the result. */
    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency)
            throws SQLException {
        String sqlWithProxyComment = wrapWithStatementComment(sql);
        PreparedStatement preparedStatement =
                super.prepareStatement(sqlWithProxyComment, resultSetType, resultSetConcurrency);
        return new DbProxyPreparedStatement(preparedStatement, this);
    }

    /** Prepares {@code sql} with the hint comment applied and wraps the result. */
    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency,
                                              int resultSetHoldability) throws SQLException {
        String sqlWithProxyComment = wrapWithStatementComment(sql);
        PreparedStatement preparedStatement =
                super.prepareStatement(sqlWithProxyComment, resultSetType, resultSetConcurrency, resultSetHoldability);
        return new DbProxyPreparedStatement(preparedStatement, this);
    }

    /** Prepares {@code sql} with the hint comment applied and wraps the result. */
    @Override
    public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
        String sqlWithProxyComment = wrapWithStatementComment(sql);
        PreparedStatement preparedStatement = super.prepareStatement(sqlWithProxyComment, autoGeneratedKeys);
        return new DbProxyPreparedStatement(preparedStatement, this);
    }

    /** Prepares {@code sql} with the hint comment applied and wraps the result. */
    @Override
    public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
        String sqlWithProxyComment = wrapWithStatementComment(sql);
        PreparedStatement preparedStatement = super.prepareStatement(sqlWithProxyComment, columnIndexes);
        return new DbProxyPreparedStatement(preparedStatement, this);
    }

    /** Prepares {@code sql} with the hint comment applied and wraps the result. */
    @Override
    public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
        String sqlWithProxyComment = wrapWithStatementComment(sql);
        PreparedStatement preparedStatement = super.prepareStatement(sqlWithProxyComment, columnNames);
        return new DbProxyPreparedStatement(preparedStatement, this);
    }

    /** Wraps the delegate's statement in a {@code DbProxyStatement} bound to this connection. */
    @Override
    public Statement createStatement() throws SQLException {
        Statement statement = super.createStatement();
        return new DbProxyStatement(statement, this);
    }

    /** Wraps the delegate's statement in a {@code DbProxyStatement} bound to this connection. */
    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {
        Statement statement = super.createStatement(resultSetType, resultSetConcurrency);
        return new DbProxyStatement(statement, this);
    }

    /** Wraps the delegate's statement in a {@code DbProxyStatement} bound to this connection. */
    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability)
            throws SQLException {
        Statement statement = super.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
        return new DbProxyStatement(statement, this);
    }

    /**
     * Mutable set of DbProxy hints, serialized to JSON with Jackson. {@code null} string properties
     * are omitted from the output; {@code is_tx} is always written.
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    private static class ConnectionProperties {

        private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper()
                .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
                .configure(DeserializationFeature.FAIL_ON_IGNORED_PROPERTIES, false)
                .configure(DeserializationFeature.FAIL_ON_INVALID_SUBTYPE, false)
                .setSerializationInclusion(JsonInclude.Include.NON_NULL)
                .configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);

        /** Router hint; {@code "m"} by default, {@code null} to omit. */
        private String router = USE_MASTER_HINT;

        /** Mode hint, e.g. {@code "shadow"}; {@code null} to omit. */
        private String mode = null;

        /** Sharding value hint (already trimmed); {@code null} to omit. */
        private String pid = null;

        /** Transaction flag, serialized as {@code is_tx}. */
        private boolean isTx = false;

        public String getRouter() {
            return router;
        }

        public String getMode() {
            return mode;
        }

        public String getPid() {
            return pid;
        }

        @JsonProperty("is_tx")
        public Boolean isTx() {
            return isTx;
        }

        /**
         * Returns whether any hint that warrants a comment is set.
         * NOTE(review): {@code isTx} is deliberately(?) not considered; confirm with DbProxy usage.
         */
        private boolean hasStatementComment() {
            return null != router || null != mode || null != pid;
        }

        /**
         * Serializes the hints to JSON that is safe to embed in a SQL block comment.
         *
         * @return the JSON text, or {@code null} if no hint warrants a comment
         * @throws IllegalStateException if Jackson fails to serialize the properties
         */
        String encodeToStatementComment() {
            if (!hasStatementComment()) {
                return null;
            }

            String json;
            try {
                json = OBJECT_MAPPER.writeValueAsString(this);
            } catch (JsonProcessingException e) {
                String msg = String.format("encode statement comment failed, comment=%s", this);
                throw new IllegalStateException(msg, e);
            }
            // The JSON ends up inside "/* ... */". A "*/" in a value (e.g. a sharding value that
            // comes from user input) would close the comment early and turn the remainder into
            // executable SQL. '*' can only occur inside JSON string literals, where "\/" is a valid
            // escape for '/', so this keeps the JSON equivalent (assuming DbProxy parses the hint
            // as standard JSON; NOTE(review): confirm).
            return json.replace(COMMENT_TERMINATOR, ESCAPED_COMMENT_TERMINATOR);
        }

        @Override
        public String toString() {
            return "ConnectionProperties{router=" + router + ", mode=" + mode + ", pid=" + pid
                    + ", isTx=" + isTx + "}";
        }
    }
}